Mulgrad
计算逐元素乘法 (Mul) 操作的梯度。该算子是 Mul 算子的反向传播(backward pass)部分。梯度的计算遵循链式法则。
\[ \begin{aligned} \text{dx1}_i &= \text{dy}_i \times \text{x2}_i \\ \text{dx2}_i &= \text{dy}_i \times \text{x1}_i \end{aligned} \]
其中 dx1 和 dx2 分别是损失函数对前向输入 x1 和 x2 的梯度。
gradmul1l 版本专门用于 x1 张量维度大于或等于 x2 张量的广播场景；gradmul2l 版本专门用于 x2 张量维度大于或等于 x1 张量的广播场景。
- 输入:
dy - 来自后一层的上游梯度张量。
x1 - 前向传播时的第一个输入张量(第一个乘数)。
x2 - 前向传播时的第二个输入张量(第二个乘数)。
large_shape - x1 和 x2 中维度较大的张量的形状。
small_shape - x1 和 x2 中维度较小的张量的形状。
out_shape - 输出张量 dx1 和 dx2 的形状。
ndims - 张量的维度数。
large_strides - 维度较大张量的步长信息。
small_strides - 维度较小张量的步长信息。
out_strides - 输出张量的步长信息。
large_multiples - 维度较大张量的广播倍数。
small_multiples - 维度较小张量的广播倍数。
tile_data0 - 临时工作空间地址。
tile_data1 - 临时工作空间地址。
indices - 用于广播计算的临时索引空间地址。
core_mask - 核掩码。
- 输出:
dx1 - 写入计算出的对 x1 的梯度。
dx2 - 写入计算出的对 x2 的梯度。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持fp32
MT7004 支持fp16, fp32
共享存储版本:
-
void fp_gradmul_s(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices, int core_mask)
-
void hp_gradmul_s(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1, int *indices, int core_mask)
-
void fp_gradmul1l_s(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices, int core_mask)
-
void hp_gradmul1l_s(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1, int *indices, int core_mask)
-
void fp_gradmul2l_s(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices, int core_mask)
-
void hp_gradmul2l_s(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1, int *indices, int core_mask)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <divgrad.h>
4int main(int argc, char* argv[]) {
5 float *dy = (float *)0xA1000000;
6 float *dx1 = (float *)0xA2000000;
7 float *dx2 = (float *)0xA3000000;
8 float *x1_data = (float *)0xA4000000;
9 float *x2_data = (float *)0xA5000000;
10 float *tile_data0 = (float *)0xA6000000;
11 float *tile_data1 = (float *)0xA7000000;
12
13 long long ndims = 4;
14 long long dy_size;
15 long long x1_size;
16 long long x2_size;
17
18 int *large_strides = (int *)0xAB000000;
19 int *small_strides = (int *)0xAB100000;
20 int *out_strides = (int *)0xAB200000;
21 int *large_multiples = (int *)0xAB300000;
22 int *small_multiples = (int *)0xAB400000;
23 int *indices = (int *)0xAB500000;
24 int *large_shape = (int *)0xAB600000;
25 int *small_shape = (int *)0xAB700000;
26 int *out_shape = (int *)0xAB800000;
27
28 large_shape[0] = 12; large_shape[1] = 14; large_shape[2] = 3; large_shape[3] = 5;
29 small_shape[0] = 12; small_shape[1] = 14; small_shape[2] = 3; small_shape[3] = 5;
30 out_shape[0] = 12; out_shape[1] = 14; out_shape[2] = 3; out_shape[3] = 5;
31
32 int core_mask = 0xff;
33
34 dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
35 x1_size = large_shape[0] * large_shape[1] * large_shape[2] * large_shape[3];
36 x2_size = small_shape[0] * small_shape[1] * small_shape[2] * small_shape[3];
37
38 fp_gradmul_s(dy, x1, x2, large_shape, small_shape, out_shape, ndims, large_strides, small_strides, out_strides, large_multiples, small_multiples, dx1, dx2, tile_data0, tile_data1, indices, core_mask);
39 return 0;
40}
私有存储版本:
-
void fp_gradmul_p(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices)
-
void hp_gradmul_p(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1)
-
void fp_gradmul1l_p(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices)
-
void hp_gradmul1l_p(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1, int *indices)
-
void fp_gradmul2l_p(float *dy, float *x1, float *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, float *dx1, float *dx2, float *tile_data0, float *tile_data1, int *indices)
-
void hp_gradmul2l_p(half *dy, half *x1, half *x2, int *large_shape, int *small_shape, int *out_shape, int ndims, int *large_strides, int *small_strides, int *out_strides, int *large_multiples, int *small_multiples, half *dx1, half *dx2, half *tile_data0, half *tile_data1, int *indices)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <mulgrad.h>
4int main(int argc, char* argv[]) {
5 float *dy = (float *)0x10000000;
6 float *dx1 = (float *)0x12000000;
7 float *dx2 = (float *)0x13000000;
8 float *x1_data = (float *)0x14000000;
9 float *x2_data = (float *)0x15000000;
10 float *tile_data0 = (float *)0x16000000;
11 float *tile_data1 = (float *)0x17000000;
12
13 long long ndims = 4;
14 long long dy_size;
15 long long x1_size;
16 long long x2_size;
17
18 int *large_strides = (int *)0x1B000000;
19 int *small_strides = (int *)0x1B100000;
20 int *out_strides = (int *)0x1B200000;
21 int *large_multiples = (int *)0x1B300000;
22 int *small_multiples = (int *)0x1B400000;
23 int *indices = (int *)0x1B500000;
24 int *large_shape = (int *)0x1B600000;
25 int *small_shape = (int *)0x1B700000;
26 int *out_shape = (int *)0x1B800000;
27
28 large_shape[0] = 12; large_shape[1] = 14; large_shape[2] = 3; large_shape[3] = 5;
29 small_shape[0] = 12; small_shape[1] = 14; small_shape[2] = 3; small_shape[3] = 5;
30 out_shape[0] = 12; out_shape[1] = 14; out_shape[2] = 3; out_shape[3] = 5;
31
32 dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
33 x1_size = large_shape[0] * large_shape[1] * large_shape[2] * large_shape[3];
34 x2_size = small_shape[0] * small_shape[1] * small_shape[2] * small_shape[3];
35
36 fp_gradmul_p(dy, x1, x2, large_shape, small_shape, out_shape, ndims, large_strides, small_strides, out_strides, large_multiples, small_multiples, dx1, dx2, tile_data0, tile_data1, indices);
37 return 0;
38}